from torch.optim import SGD
from .sam import SAM
from .sgdgr import SGDGR
from .saner import SANER
from .fsam import FSAM
from .fsamsaner import FSAMSANER
from .vasso import VASSO
from .vassosaner import VASSOSANER
from .gsam import GSAM
from .gsamsaner import GSAMSANER

def get_optimizer(
    net,
    opt_name='sam',
    opt_hyperpara=None):
    """Build the optimizer selected by ``opt_name`` for the parameters of ``net``.

    The name is matched case-insensitively, so the historical ``'SANER'``
    spelling and the lower-case ``'saner'`` both select the same class.

    Args:
        net: Object exposing a ``parameters()`` method; its return value is
            passed as the first positional argument to the optimizer
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        opt_name: One of ``'sam'``, ``'sgd'``, ``'sgdgr'``, ``'saner'``,
            ``'fsam'``, ``'fsamsaner'``, ``'vasso'``, ``'vassosaner'``,
            ``'gsam'``, ``'gsamsaner'`` (any letter case).
        opt_hyperpara: Mapping of keyword arguments forwarded to the
            optimizer constructor. ``None`` (the default) means no extra
            keyword arguments.

    Returns:
        An instance of the selected optimizer class.

    Raises:
        ValueError: If ``opt_name`` is not a string or does not name a
            known optimizer.
    """
    # A None default avoids sharing one dict object across calls.
    hyperparams = {} if opt_hyperpara is None else opt_hyperpara

    # Non-string names fall through to the ValueError below instead of
    # failing with AttributeError on ``.lower()``.
    key = opt_name.lower() if isinstance(opt_name, str) else None

    if key == 'sam':
        optimizer_cls = SAM
    elif key == 'sgd':
        optimizer_cls = SGD
    elif key == 'sgdgr':
        optimizer_cls = SGDGR
    elif key == 'saner':
        optimizer_cls = SANER
    elif key == 'fsam':
        optimizer_cls = FSAM
    elif key == 'fsamsaner':
        optimizer_cls = FSAMSANER
    elif key == 'vasso':
        optimizer_cls = VASSO
    elif key == 'vassosaner':
        optimizer_cls = VASSOSANER
    elif key == 'gsam':
        optimizer_cls = GSAM
    elif key == 'gsamsaner':
        optimizer_cls = GSAMSANER
    else:
        raise ValueError(f"Invalid optimizer!!! Got {opt_name!r}.")

    # net.parameters() is only evaluated once a valid class has been chosen,
    # so an unknown name is reported before ``net`` is touched.
    return optimizer_cls(net.parameters(), **hyperparams)